Fecha de publicación

26 de noviembre de 2024

Objetivo del manual

  • Entender como se construyen los árboles de decisión

  • Familiarizarse con los principales métodos de regresión y clasificación basados en árboles

  • Aprender a aplicar estos estos métodos en R

Paquetes a utilizar en este manual:

Código
# instalar/cargar paquetes
sketchy::load_packages(
  c("ggplot2", 
    "viridis", 
    "caret",
    "ISLR",
    "rpart",
    "rpart.plot",
    "tree"
   )
  )
Loading required package: caret
Loading required package: lattice

Los métodos basados en árboles estratifican o subdividen el espacio predictor en en una serie de regiones simples. Dado que el conjunto de reglas de división utilizadas para segmentar el espacio predictor puede representarse en un árbol, este tipo de enfoques se conocen como métodos de árboles de decisión. Los árboles de decisión se utilizan tanto para clasificación como para regresión. Su ventaja radica en la simplicidad de interpretación y en su capacidad para capturar interacciones no lineales entre predictores.

1 Árboles de regresión

Usaremos el conjunto de datos Hitters para predecir el salario de un jugador de béisbol basado en los años (“Years”, el número de años que ha jugado en las grandes ligas) y “Hits” (el número de hits que realizó en el año anterior). Primero eliminamos las observaciones que tienen valores de salario faltantes (“Salary”) y aplicamos una transformación logarítmica al salario para que su distribución tenga una forma más típica de campana.

La siguiente figura muestra un árbol de regresión ajustado a estos datos. Consiste en una serie de reglas de división, comenzando desde la parte superior del árbol. La división superior asigna las observaciones con años < 4.5 a la rama izquierda. El salario predicho para estos jugadores se da por el valor promedio de respuesta para los jugadores en el conjunto de datos con años < 4.5. Para tales jugadores, el salario logarítmico medio es 5.107, por lo que hacemos una predicción de e^5.107 miles de dólares, es decir, $165,174, para estos jugadores. Los jugadores con Años >= 4.5 son asignados a la rama derecha, y luego ese grupo se subdivide aún más por “Hits”.

Código
# cargar datos
data(Hitters)

# Eliminar observaciones con valores faltantes en Salary
Hitters <- na.omit(Hitters)

# Transformar la variable Salary al logaritmo natural
Hitters$LogSalary <- log(Hitters$Salary)

# Ajustar un árbol de regresión con un número limitado de nodos terminales
set.seed(42) # Para reproducibilidad
modelo_arbol <- tree(LogSalary ~ Years + Hits,
                   data = Hitters)

# Podar el árbol para limitarlo a 3 nodos terminales (si es necesario)
modelo_podado <- prune.tree(modelo_arbol, best = 3)

# Graficar el árbol
plot(modelo_podado)
text(modelo_podado, pretty = 0) # Agrega etiquetas legibles

En general, el árbol divide a los jugadores en tres regiones del espacio de predictores: jugadores que han jugado durante cuatro años o menos, jugadores que han jugado durante cinco años o más y que hicieron menos de 118 hits el año pasado, y jugadores que han jugado durante cinco años o más y que hicieron al menos 118 hits el año pasado. Estas tres regiones se pueden escribir como R1 = {X | Años < 4.5}, R2 = {X | Años >= 4.5, Hits < 117.5}, y R3 = {X | Años >= 4.5, Hits >= 117.5}. La Figura 8.2 ilustra las regiones como función de los Años y Hits. Los salarios predichos para estos tres grupos son $1,000 × e^5.107 = $165,174, $1,000 × e^5.999 = $402,834, y $1,000 × e^6.740 = $845,346 respectivamente.

Siguiendo con la analogía del árbol, las regiones R1, R2 y R3 se conocen como nodos terminales o hojas del árbol. Los puntos a lo largo del árbol donde se divide el espacio del predictor se denominan nodos internos. En el árbol graficado mas arriba los dos nodos internos están indicados por el texto “Years < 4.5” y “Hits < 117.5”. Nos referimos a los segmentos del árbol que conectan los nodos como ramas.

Podríamos interpretar el árbol de regresión mostrado mas arriba de la siguiente manera: Los años son el factor más importante para determinar el salario, y los jugadores con menos experiencia ganan salarios más bajos que los jugadores más experimentados. Dado que un jugador tiene menos experiencia, parece que el número de hits que realizó en el año anterior juega un papel poco relevante en su salario. Pero entre los jugadores que han estado en las grandes ligas durante cinco años o más, el número de hits realizados en el año anterior sí afecta al salario, y los jugadores que hicieron más hits el año pasado tienden a tener salarios más altos. El árbol de regresión mostrado arriba es una sobre-simplificación de la verdadera relación entre hits, años y salario. Sin embargo, tiene ventajas sobre otros tipos de modelos de regresión ya que es más fácil de interpretar y tiene una representación gráfica atractiva.

El árbol que ajustamos a estos datos en realidad es mucho mas complejo y contiene mas divisiones de los datos en mas estratos. Los paquetes rpart y rpart.plot nos permite ajustra y visualizar estos modelos de una forma muy amigable:

Código
# Ajustar un árbol de regresión
arbol_regresion <- rpart(LogSalary ~ Years + Hits, data = Hitters)

# Visualizar el árbol
rpart.plot(arbol_regresion, extra = 101)

La complejidad del árbol la podemos controlar mediante el parámetro cp. Un valor más alto de cp resulta en un árbol más simple (con menos divisiones o reglas), mientras que un valor más bajo permite árboles más grandes y complejos. El objetivo es encontrar un equilibrio entre un árbol que sea suficientemente complejo para capturar patrones importantes en los datos, pero no tan complejo que incurra en sobreajuste. Por ejemplo este es un árbol con un valor de cp de 0.05:

Código
arbol_cp.05 <- rpart(LogSalary ~ Years + Hits,
                     data = Hitters,
                     control = rpart.control(cp = 0.05))

rpart.plot(arbol_cp.05, extra = 101)

Este en cambio tiene un valor de cp de 0.001:

Código
arbol_cp.001 <- rpart(LogSalary ~ Years + Hits, data = Hitters, control = rpart.control(cp = 0.001))

rpart.plot(arbol_cp.001, extra = 101)

1.1 Validación cruzada

Afortunadamente la librería caret nos permite realizar validación cruzada para encontrar el mejor valor de cp para nuestro modelo. Podemos utilzar cualquiera de los métodos vistos en el manual de “sobreajuste y entrenamiento de modelos”. En este caso usamos el método de “dejar uno afuera” (LOOCV) para optimizar el valor de cp:

Código
# Configuración de validación cruzada
set.seed(42) # Para reproducibilidad
# Validación cruzada
train_control <- trainControl(method = "LOOCV")

# Entrenar el modelo de árbol usando caret
modelo_arbol <- train(
  LogSalary ~ Years + Hits,
  data = Hitters,
  method = "rpart", # Árbol de decisión
  trControl = train_control,
  tuneLength = 10 # Número de combinaciones de parámetros a probar
)         

# Resumen del modelo ajustado
print(modelo_arbol)
CART 

263 samples
  2 predictor

No pre-processing
Resampling: Leave-One-Out Cross-Validation 
Summary of sample sizes: 262, 262, 262, 262, 262, 262, ... 
Resampling results across tuning parameters:

  cp         RMSE     Rsquared  MAE    
  0.0042120  0.60271  0.547947  0.43523
  0.0046796  0.59658  0.555573  0.42811
  0.0085782  0.59364  0.557312  0.42568
  0.0096474  0.59458  0.554649  0.42884
  0.0110721  0.59198  0.558190  0.42732
  0.0169020  0.58909  0.562176  0.43504
  0.0183127  0.59250  0.555813  0.44211
  0.0444602  0.62855  0.501931  0.47958
  0.1145455  0.71442  0.361503  0.58884
  0.4445745  0.97762  0.016381  0.87625

RMSE was used to select the optimal model using the smallest value.
The final value used for the model was cp = 0.016902.

El mejor valor de cp encontrado por la validación cruzada es 0.00421 y un valor de RMSE de 0.60271. Podemos extraer y visualizar el árbol final asi:

Código
# extraer mejor modelo
mejor_modelo <- modelo_arbol$finalModel

# Visualizar el árbol final
rpart.plot(mejor_modelo, extra = 101)

1.2 Comparación con modelos lineales

Con el siguiente gráfico podemos comparar el comportamiento de un modelo lineal con el de un árbol de regresión, usando un ejemplo de clasificación en dos dimensiones. En este ejemplo en el que la verdadera frontera de decisión es lineal, y está indicada por las regiones sombreadas. En la fila superior se ilustra como un enfoque clásico que asume una frontera lineal (izquierda) superará a un árbol de decisión que realiza divisiones paralelas a los ejes (derecha). En la fila inferior la verdadera frontera de decisión es no lineal. En este caso, un modelo lineal no puede capturar la verdadera frontera de decisión (izquierda), mientras que un árbol de decisión tiene éxito (derecha).

Tomado de Gareth et al 2013

2 Árboles de clasificación

Al igual que los modelos de regresión con árboles de decisión, los árboles de clasificación dividen el espacio predictor en una serie de regiones simples utilizando reglas de decisión, con el objetivo de predecir una clase o categoría como resultado. Este enfoque es similar al de los árboles de regresión, pero el criterio de división optimiza una métrica asociada con la clasificación, como la entropía o el índice de Gini.

Usaremos el conjunto de datos Carseats para predecir si las ventas son altas (High = “Yes”) o bajas (High = “No”) en función de varias características del conjunto de datos, como precio (Price) o publicidad (Advertising).

Primero preprocesamos los datos para crear la variable objetivo binaria:

Código
# Cargar datos
data(Carseats)

# Crear variable de respuesta binaria
Carseats$High <- ifelse(Carseats$Sales > 8, "Yes", "No")
Carseats$High <- factor(Carseats$High)  # Convertir a factor

# Eliminar la variable 'Sales' para evitar colinealidad
Carseats$Sales <- NULL

2.1 Construcción de un Árbol de Clasificación

Ajustaremos un árbol de clasificación simple utilizando la librería rpart y visualizaremos el árbol generado.

Código
# Ajustar un árbol de clasificación
arbol_clasificacion <- rpart(High ~ Price + Advertising + ShelveLoc + Age, 
                             data = Carseats, 
                             method = "class")

# Visualizar el árbol
rpart.plot(arbol_clasificacion, extra = 104, fallen.leaves = TRUE, shadow.col = "gray")

El árbol resultante muestra cómo los datos se dividen en regiones basadas en los predictores. Por ejemplo, el predictor más importante puede ser Price, donde precios más bajos están asociados con mayores ventas.

2.1.1 Interpretación

En el gráfico del árbol: - Los nodos terminales indican la clase predicha (Yes o No) y el porcentaje de datos que pertenecen a esa clase. - Las divisiones están basadas en reglas como Price < 92.5.

2.2 Podado del Árbol

A menudo, el árbol inicial puede ser demasiado grande, lo que lleva a sobreajuste. Podamos el árbol para limitar su complejidad:

Código
# Podar el árbol usando el valor óptimo de cp
set.seed(42)
arbol_podado <- prune(arbol_clasificacion, cp = 0.02)

# Visualizar el árbol podado
rpart.plot(arbol_podado, extra = 104)

Al podar el árbol, reducimos su complejidad, haciendo que sea más interpretable y menos propenso al sobreajuste.

2.3 Evaluación del Árbol de Clasificación

Para evaluar el desempeño del modelo, dividiremos los datos en conjuntos de entrenamiento y prueba y generaremos una matriz de confusión:

Código
# Dividir en conjunto de entrenamiento y prueba
set.seed(42)
train_index <- sample(seq_len(nrow(Carseats)), size = 0.7 * nrow(Carseats))
train_data <- Carseats[train_index, ]
test_data <- Carseats[-train_index, ]

# Ajustar árbol con datos de entrenamiento
arbol_train <- rpart(High ~ Price + Advertising + ShelveLoc + Age, 
                     data = train_data, 
                     method = "class")

# Predicciones en el conjunto de prueba
predicciones <- predict(arbol_train, test_data, type = "class")

# Matriz de confusión
matriz_confusion <- table(test_data$High, predicciones)
print(matriz_confusion)
     predicciones
      No Yes
  No  64  10
  Yes 16  30

La matriz de confusión muestra el desempeño del modelo en términos de precisión y tasa de error. Por ejemplo:

  • Verdaderos positivos (Yes): matriz_confusion["Yes", "Yes"]
  • Falsos positivos (Yes): matriz_confusion["No", "Yes"]

Podemos calcular métricas adicionales como la precisión o el F1-score si es necesario.

2.4 Validación Cruzada

Para seleccionar el mejor valor de cp, usamos validación cruzada con la librería caret:

Código
# Configuración de validación cruzada
set.seed(42)
control <- trainControl(method = "cv", number = 10)

# Ajustar modelo con validación cruzada
modelo_cv <- train(
  High ~ Price + Advertising + ShelveLoc + Age, 
  data = train_data,
  method = "rpart",
  trControl = control,
  tuneLength = 10
)

# Mostrar los resultados
print(modelo_cv)
CART 

280 samples
  4 predictor
  2 classes: 'No', 'Yes' 

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 252, 252, 252, 253, 251, 253, ... 
Resampling results across tuning parameters:

  cp        Accuracy  Kappa   
  0.000000  0.70293   0.381306
  0.028249  0.70599   0.381204
  0.056497  0.67383   0.325415
  0.084746  0.66680   0.291874
  0.112994  0.66668   0.274257
  0.141243  0.66668   0.274257
  0.169492  0.66668   0.274257
  0.197740  0.66668   0.274257
  0.225989  0.66668   0.274257
  0.254237  0.60375   0.089011

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was cp = 0.028249.
Código
# Visualizar el árbol final
rpart.plot(modelo_cv$finalModel, extra = 104)

El valor óptimo de cp es 0.02825. El árbol final representa la mejor combinación de simplicidad y precisión según la validación cruzada.

2.5 Comparación con Modelos Lineales

Al igual que con los árboles de regresión, podemos comparar el desempeño de un árbol de clasificación con un modelo lineal o logístico:

Código
# Ajustar un modelo logístico
modelo_logistico <- glm(High ~ Price + Advertising + ShelveLoc + Age, 
                        data = train_data, 
                        family = "binomial")

# Predicciones
pred_log <- predict(modelo_logistico, test_data, type = "response")
pred_log_clas <- ifelse(pred_log > 0.5, "Yes", "No")

# Matriz de confusión para el modelo logístico
matriz_conf_log <- table(test_data$High, pred_log_clas)
print(matriz_conf_log)
     pred_log_clas
      No Yes
  No  64  10
  Yes 11  35

El modelo logístico puede mostrar mayor o menor precisión dependiendo de la relación entre las variables predictoras y la variable respuesta.


2.5.1 Ejercicio Propuesto

  1. Ajusta un árbol de clasificación utilizando otras variables predictoras (por ejemplo, Income o Population).
  2. Realiza validación cruzada con un enfoque distinto (por ejemplo, validación LOOCV).
  3. Compara los resultados con un modelo K-Vecinos más Cercanos (kNN).

Con este esquema puedes cubrir los conceptos clave y realizar análisis detallados para árboles de clasificación. ¿Te gustaría que amplíe alguna sección o incluya más gráficos?

2.6 Ejercicio 1


Referencias

Gareth, J., Daniela, W., Trevor, H., & Robert, T. (2013). An introduction to statistical learning: with applications in R. Spinger.

Información de la sesión

R version 4.3.2 (2023-10-31)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.2 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.10.0 
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.10.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

time zone: America/Bogota
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] caret_6.0-94      lattice_0.20-45   tree_1.0-43       rpart.plot_3.1.2 
 [5] rpart_4.1.16      ISLR_1.4          viridis_0.6.5     viridisLite_0.4.2
 [9] ggplot2_3.5.1     knitr_1.48       

loaded via a namespace (and not attached):
 [1] gtable_0.3.5         xfun_0.47            htmlwidgets_1.6.4   
 [4] recipes_1.0.9        remotes_2.5.0        vctrs_0.6.5         
 [7] tools_4.3.2          generics_0.1.3       stats4_4.3.2        
[10] parallel_4.3.2       proxy_0.4-27         tibble_3.2.1        
[13] fansi_1.0.6          ModelMetrics_1.2.2.2 pkgconfig_2.0.3     
[16] Matrix_1.6-5         data.table_1.14.10   lifecycle_1.0.4     
[19] compiler_4.3.2       stringr_1.5.1        munsell_0.5.1       
[22] codetools_0.2-18     sketchy_1.0.3        htmltools_0.5.8.1   
[25] class_7.3-20         yaml_2.3.10          prodlim_2023.08.28  
[28] pillar_1.9.0         crayon_1.5.3         MASS_7.3-55         
[31] gower_1.0.1          iterators_1.0.14     foreach_1.5.2       
[34] parallelly_1.38.0    nlme_3.1-155         lava_1.7.3          
[37] tidyselect_1.2.1     packrat_0.9.2        digest_0.6.37       
[40] stringi_1.8.4        future_1.34.0        reshape2_1.4.4      
[43] purrr_1.0.2          listenv_0.9.1        dplyr_1.1.4         
[46] splines_4.3.2        fastmap_1.2.0        grid_4.3.2          
[49] colorspace_2.1-1     cli_3.6.3            magrittr_2.0.3      
[52] survival_3.2-13      utf8_1.2.4           e1071_1.7-16        
[55] future.apply_1.11.2  withr_3.0.1          scales_1.3.0        
[58] timechange_0.2.0     lubridate_1.9.3      rmarkdown_2.28      
[61] globals_0.16.3       nnet_7.3-17          timeDate_4032.109   
[64] gridExtra_2.3        evaluate_1.0.0       hardhat_1.3.0       
[67] rlang_1.1.4          Rcpp_1.0.13          glue_1.8.0          
[70] xaringanExtra_0.8.0  pROC_1.18.5          ipred_0.9-14        
[73] rstudioapi_0.16.0    jsonlite_1.8.9       R6_2.5.1            
[76] plyr_1.8.9